# coding: utf-8

# -------------------------------------------------------------------------------
# Name:         baidu_ocr_client.py
# Description:  百度OCR客户端
# Author:       XiangjunZhao
# EMAIL:        2419352654@qq.com
# Date:         2020/3/16 18:54
# -------------------------------------------------------------------------------
import base64
import logging
import requests

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class BaiduOcrClient(object):
    """
    Client for Baidu's ``general_basic`` OCR HTTP endpoint.

    Constructing the client requests an OAuth access token with the given
    credentials. On success the token is stored in ``access_token`` and
    substituted into ``ocr_url``; on failure ``access_token`` stays ``''``.
    """

    # Upper bound on token requests per call; the token fetch used to retry
    # without limit (by recursion) whenever the server answered non-200.
    MAX_TOKEN_ATTEMPTS = 3
    # Seconds allowed for a single HTTP request, so a stalled connection
    # cannot block the caller forever.
    REQUEST_TIMEOUT = 30

    def __init__(self, api_key=None, secret_key=None):
        """
        Store the credentials and immediately try to obtain an access token.

        Args:
            api_key: API key, sent as ``client_id`` in the token request.
            secret_key: Secret key, sent as ``client_secret`` in the token request.
        """
        self.token_url = 'https://aip.baidubce.com/oauth/2.0/token'
        self.ocr_url = 'https://aip.baidubce.com/rest/2.0/ocr/v1/general_basic?access_token={}'
        self.api_key = api_key
        self.secret_key = secret_key
        self.access_token = ''
        self.__get_access_key()

    def __get_access_key(self):
        """
        Request an access token and store it on the instance.

        Tries at most ``MAX_TOKEN_ATTEMPTS`` times when the token endpoint
        answers with a non-200 status. A 200 response that lacks
        ``access_token`` or ``scope`` is treated as a credential problem and
        is not retried.

        Returns:
            None. On success ``self.access_token`` is set and
            ``self.ocr_url`` has the token substituted in.
        """
        params = {
            'grant_type': 'client_credentials',
            'client_id': self.api_key,
            'client_secret': self.secret_key
        }
        for _ in range(self.MAX_TOKEN_ATTEMPTS):
            res = requests.post(url=self.token_url, params=params, timeout=self.REQUEST_TIMEOUT)
            if res.status_code != 200:
                logger.warning('获取百度access_token失败')
                continue

            result = res.json()
            if 'access_token' not in result or 'scope' not in result:
                # Retrying cannot help when the response itself is malformed
                # for these credentials.
                logger.error('请提供正确的API_KEY和SECRET_KEY')
                return

            if 'brain_all_scope' not in result['scope'].split(' '):
                # Only a warning: the token is still stored and used below.
                logger.warning('请确认是否有OCR识别的权限')

            self.access_token = result['access_token']
            self.ocr_url = self.ocr_url.format(self.access_token)
            return

        logger.error('Giving up on Baidu access_token after %d attempts', self.MAX_TOKEN_ATTEMPTS)

    @staticmethod
    def read_image_content(path=None):
        """
        Read a file's raw bytes.

        Args:
            path: Path of the image file to read.

        Returns:
            The file content as bytes, or None when the file does not exist
            (the error is logged).
        """
        try:
            with open(file=path, mode='rb') as f:
                return f.read()
        except FileNotFoundError as e:
            logger.error(str(e))
            return None

    def get_ocr_result(self, data):
        """
        Send image bytes to the OCR endpoint and return the recognized text.

        Args:
            data: Raw image bytes, e.g. the result of ``read_image_content``.
                Empty or None input yields ``''`` without a network call.

        Returns:
            The ``words`` entries of the response's ``words_result`` list
            joined with newlines, or ``''`` when there is nothing to send or
            the response status is not 200.
        """
        # read_image_content returns None for a missing file; b64encode(None)
        # would raise TypeError, so stop here instead.
        if not data:
            return ''

        headers = {
            'Content-Type': 'application/x-www-form-urlencoded'
        }
        payload = {'image': base64.b64encode(data)}
        res = requests.post(url=self.ocr_url, headers=headers, data=payload, timeout=self.REQUEST_TIMEOUT)
        if res.status_code != 200:
            logger.error('OCR request failed with HTTP status %s', res.status_code)
            return ''

        result = res.json()
        return '\n'.join(item.get('words', '') for item in result.get('words_result', []))
